/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace odml {

// Converts `updates` into the canonical form expected by the tf.scatter ops.
//
// tf.scatter expects `update_window_dims` to be the trailing dimensions of
// `updates`.
//
// To support scatter ops generated by numpy-like slice updates, e.g.
//   nd_array[:, [i,j]] = [i_values, j_values]
// `updates` must be transposed when the update_window_dims are the leading
// dimensions of `updates`.
//
// Other values of `update_window_dims` are not supported.
//
// Example 1. An update already in canonical form:
//  * indices shape (A,B,C)
//  * updates shape (A,B,D,E,F)
// Then:
//  * D,E,F are the update window dims [2,3,4]
//  * C is the index vector dimension
//  * A,B iterate over the updates and indices
//
// If `update_window_dims` are not the trailing dimensions, `updates` must be
// transposed.
//
// Example 2. An update in non-canonical form:
//  * indices shape (a,b,c)
//  * updates shape (d,e,f,a,b)
// Then:
//  * d,e,f are the update window dims [0,1,2]
//  * c is the index vector dimension
//  * a,b iterate over the updates and indices
//
// The update must be permuted into the form (a,b,d,e,f) so that the update
// window dims are the trailing dimensions. To canonicalize it, replace
// `updates` with:
//   transpose(updates, permutation={3,4,0,1,2})
// Result dimension i is taken from input dimension permutation[i], so
// (d,e,f,a,b) becomes (a,b,d,e,f).
//
// Note: NormalizeIndexVector is assumed to have already run on the indices,
// so that the index_vector_dim is the trailing dimension of `indices`.
//
// Parameters:
//  * scatter_op: the scatter operation being converted. Presumably used as
//    the anchor/location for any newly created ops; confirm in scatter.cc.
//  * update_window_dims: the scatter's update_window_dims.
//  * indices, indices_type: the (already normalized) indices value and its
//    shaped type.
//  * updates, updates_type: in/out. They are taken by non-const reference so
//    the implementation can replace them, presumably with the transposed
//    value and its type when a transpose is needed.
//  * rewriter: the conversion rewriter of the calling pattern.
// Returns success() or failure(). Presumably failure() is returned for the
// unsupported `update_window_dims` layouts mentioned above; TODO confirm in
// scatter.cc.
LogicalResult CanonicalizeScatterUpdates(
    Operation* scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
    const Value& indices, const ShapedType& indices_type, Value& updates,
    ShapedType& updates_type, ConversionPatternRewriter& rewriter);

// Conversion pattern for mhlo::ScatterOp (it derives from
// OpConversionPattern<mhlo::ScatterOp>), parameterized on the TF scatter op
// it targets.
//
// Template parameters:
//  * BinaryOp: the mhlo binary op paired with `TfOp`, e.g. mhlo::AddOp for
//    TF::TensorScatterAddOp. `void` is passed for TF::TensorScatterUpdateOp,
//    which has no binary-op counterpart. Presumably this matches a scatter
//    whose update computation simply yields the update value; confirm in
//    scatter.cc.
//  * TfOp: the TF scatter op to rewrite into.
//
// matchAndRewrite is only declared here; its definition lives out of line
// (presumably in scatter.cc). Keep this declaration in sync with it.
template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
 public:
  // Reuse the OpConversionPattern constructors unchanged.
  using OpConversionPattern::OpConversionPattern;

  // Attempts to rewrite `scatter_op` (with already-converted operands in
  // `adaptor`) into `TfOp`. Returns failure() when the op is not handled by
  // this specialization.
  LogicalResult matchAndRewrite(
      mhlo::ScatterOp scatter_op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final;
};

// Named patterns, one per supported TF scatter op.

// TF::TensorScatterUpdateOp has no mhlo binary-op counterpart, so `void` is
// passed as the BinaryOp parameter.
using ConvertScatterUpdateOp =
    ConvertScatterOp<void, TF::TensorScatterUpdateOp>;

// Arithmetic combiners.
using ConvertScatterAddOp =
    ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
using ConvertScatterSubOp =
    ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;

// Extremum combiners.
using ConvertScatterMaxOp =
    ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
using ConvertScatterMinOp =
    ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;

// Explicit instantiation definitions, one per supported (BinaryOp, TfOp) pair.
// An explicit instantiation cannot name a type alias, so each specialization
// is spelled out in full. These mirror the aliases declared above.
//
// Because matchAndRewrite is only declared above (not defined in this header),
// these instantiations do not by themselves instantiate its definition.
//
// NOTE(review): explicit instantiation *definitions* in a header are emitted
// in every translation unit that includes it. The C++ standard allows at most
// one per program (ill-formed, no diagnostic required); toolchains typically
// tolerate it via weak linkage. A cleaner layout is:
//   - `extern template class ConvertScatterOp<...>;` declarations here, and
//   - the matching `template class ...;` definitions in scatter.cc, after
//     matchAndRewrite is defined.
// This must be changed together with scatter.cc; changing only this header
// would leave the specializations uninstantiated and fail at link time.
template class ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
template class ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
template class ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
template class ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;
template class ConvertScatterOp<void, TF::TensorScatterUpdateOp>;

}  // end namespace odml
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_
